import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib
import torchvision
import random
from torch.utils.data import Subset
import torch.fft

from tqdm import tqdm
#from tensorflow.keras.datasets import mnist

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

def add_high_frequency_noise(image, noise_level=0.1):
    """
    Add complex Gaussian noise to the high-frequency Fourier components of an image.

    The 2-D FFT is taken over the last two (spatial) dimensions and centred with
    ``fftshift``. Every frequency whose squared distance from the centre is at least
    ``(min(H, W) // 4) ** 2`` receives random noise scaled by ``noise_level``; the
    central low-frequency disc is left untouched. The spectrum is transformed back
    and only the real part is returned.

    Args:
        image (torch.Tensor): Real image tensor with shape (..., H, W), e.g. (C, H, W)
            or a batch (N, C, H, W).
        noise_level (float): Multiplier applied to the standard-normal complex noise.

    Returns:
        torch.Tensor: Real tensor with the same shape as ``image``.
    """
    spatial_dims = (-2, -1)
    height, width = image.shape[-2:]

    # Spatial domain -> frequency domain, zero frequency moved to the centre.
    # Only the spatial dims are shifted: the default (dim=None) would also roll
    # the channel/batch dims.
    spectrum = torch.fft.fftshift(torch.fft.fft2(image, dim=spatial_dims), dim=spatial_dims)

    # (H, W) mask: 1 outside the central low-frequency disc, 0 inside. Built with
    # broadcasting instead of a Python double loop, since this runs once per sample.
    center_h, center_w = height // 2, width // 2
    radius = min(height, width) // 4  # Radius can be adjusted as needed
    rows = torch.arange(height, device=image.device).view(-1, 1)
    cols = torch.arange(width, device=image.device).view(1, -1)
    high_freq_mask = ((rows - center_h) ** 2 + (cols - center_w) ** 2) >= radius ** 2

    # Add noise only where the mask is set; the mask broadcasts over leading dims.
    noise = torch.randn_like(spectrum) * noise_level
    spectrum = spectrum + noise * high_freq_mask.to(spectrum.real.dtype)

    # Undo the centring and go back to the spatial domain.
    spectrum = torch.fft.ifftshift(spectrum, dim=spatial_dims)
    return torch.fft.ifft2(spectrum, dim=spatial_dims).real

# Dataset transform that applies a base transform and then adds high-frequency Fourier noise
class AddNoiseTransform:
    """Callable transform: run ``base_transform``, then add high-frequency noise.

    Meant to be passed as a dataset ``transform``. Every call draws new random
    noise via ``add_high_frequency_noise``.

    Args:
        base_transform: Callable applied first; expected to return a (C, H, W)
            tensor (e.g. a ToTensor + Normalize pipeline).
        noise_level (float): Noise intensity forwarded to ``add_high_frequency_noise``.
    """

    def __init__(self, base_transform, noise_level=0.1):
        self.base_transform = base_transform
        self.noise_level = noise_level

    def __call__(self, image):
        # Base preprocessing first, so the noise is added to the normalised tensor.
        return add_high_frequency_noise(self.base_transform(image), self.noise_level)

# Data preprocessing
# Per-channel mean/std of the CIFAR-10 training images. The values used before,
# (0.5071, 0.4865, 0.4409) / (0.2673, 0.2564, 0.2761), are the CIFAR-100
# statistics and do not standardise CIFAR-10 inputs.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Train and test use the same deterministic pipeline (no augmentation). They are
# kept as separate objects so either one can be changed independently.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Training images get high-frequency Fourier noise (drawn inside the transform, i.e.
# on every load) after normalisation; the test set is left without added noise.
transform_train_with_noise = AddNoiseTransform(transform_train, noise_level=40.0)
train_dataset_with_noise = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train_with_noise
)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Select a class-balanced subset of the training set: num_per_class images per class (100 per class below, i.e. 1,000 images)
def create_balanced_subset(dataset, num_per_class=100, seed=42, num_classes=10):
    """
    Return a Subset with ``num_per_class`` randomly chosen samples of every class.

    Args:
        dataset: Dataset exposing a ``targets`` sequence of integer labels
            (e.g. torchvision's CIFAR10).
        num_per_class (int): Samples drawn without replacement from each class.
        seed (int): Sampling seed; the same seed always selects the same indices.
        num_classes (int): Labels ``0 .. num_classes - 1`` are sampled.

    Returns:
        torch.utils.data.Subset: Subset of ``dataset`` whose indices are grouped
        class by class (class 0 first).

    Raises:
        ValueError: If a class has fewer than ``num_per_class`` samples
            (raised by ``RandomState.choice`` with ``replace=False``).
    """
    # A private RandomState yields exactly the draws that np.random.seed(seed)
    # followed by np.random.choice would, without resetting NumPy's global RNG.
    rng = np.random.RandomState(seed)
    targets = np.asarray(dataset.targets)
    selected_indices = []

    for class_id in range(num_classes):
        class_indices = np.flatnonzero(targets == class_id)
        chosen = rng.choice(class_indices, num_per_class, replace=False)
        selected_indices.extend(chosen.tolist())  # plain Python ints

    return Subset(dataset, selected_indices)

balanced_train_dataset = create_balanced_subset(train_dataset_with_noise, num_per_class=100)

# Full-batch training: a single batch holds the whole balanced subset, so every
# epoch is exactly one optimizer step. (The previous batch_size=10000 exceeded the
# subset size and therefore behaved the same way.)
train_loader = torch.utils.data.DataLoader(
    balanced_train_dataset,
    batch_size=len(balanced_train_dataset),
    shuffle=True,
    num_workers=2,
)
# test() sums correct/total over batches with the model in eval mode, so the
# evaluation batch size changes peak memory, not the reported accuracy. Avoid
# pushing all 10,000 test images through VGG16 in one forward pass.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=2)

# Define VGG model
class VGGNet(nn.Module):
    """torchvision VGG16 whose final classifier layer outputs ``num_classes`` logits.

    Args:
        num_classes (int): Output size of the replaced last classifier layer.
        pretrained (bool): Load torchvision's ImageNet weights (downloaded on first
            use). Pass False for a randomly initialised network.
    """

    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()
        self.vgg = self._build_vgg16(pretrained)
        # Replace the last fully connected layer; read its input width rather
        # than hard-coding 4096.
        in_features = self.vgg.classifier[6].in_features
        self.vgg.classifier[6] = nn.Linear(in_features, num_classes)

    @staticmethod
    def _build_vgg16(pretrained):
        """Construct torchvision's VGG16 with whichever weights API is available."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the boolean flag.
            return models.vgg16(pretrained=pretrained)
        # torchvision >= 0.13 deprecates ``pretrained=``; IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` maps to.
        return models.vgg16(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        """Return class logits (last dim = num_classes) for an image batch ``x``."""
        return self.vgg(x)

use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'
print(use_cuda)
model = VGGNet(num_classes=10).to(device)

# Cross-entropy loss and plain SGD (no momentum, no weight decay), lr = 0.05.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)

# Randomly flip each label to a different class with probability flip_prob (label noise)

def randomize_labels_with_prob(targets, num_classes=10, flip_prob=0.2):
    """
    Return a copy of ``targets`` with each label flipped with probability ``flip_prob``.

    A flipped label is replaced by one drawn uniformly from the other
    ``num_classes - 1`` classes, so it never keeps its original value. The work is
    vectorised on the tensor's own device, which avoids a per-element ``.item()``
    call (and the host/device sync it causes on CUDA). Randomness comes from
    torch's RNG (``torch.manual_seed``), not Python's ``random`` module.

    Args:
        targets (torch.Tensor): Integer tensor of labels in ``[0, num_classes)``.
        num_classes (int): Number of classes.
        flip_prob (float): Per-label flip probability.

    Returns:
        torch.Tensor: New tensor with the same shape, dtype and device;
        ``targets`` itself is not modified.
    """
    flip = torch.rand(targets.shape, device=targets.device) < flip_prob
    # An offset in [1, num_classes - 1], added modulo num_classes, maps each label
    # uniformly onto one of the *other* classes.
    offsets = torch.randint(
        1, num_classes, targets.shape, device=targets.device, dtype=targets.dtype
    )
    flipped = (targets + offsets) % num_classes
    return torch.where(flip, flipped, targets)

# Train function
# Metric logs appended to by train() and test(); saved with the model at the end.
train_loss_history, test_accuracy_history = [], []

def train(epoch):
    """
    Run one training epoch over ``train_loader`` with label noise.

    The loss is computed against labels flipped with probability 0.2
    (``randomize_labels_with_prob``). The printed accuracy compares predictions
    with the original, unflipped labels. The epoch's mean batch loss is appended
    once to ``train_loss_history``.

    Args:
        epoch (int): Epoch number, used only for logging.

    Returns:
        float | None: Mean batch loss for the epoch, or None if the loader
        yielded no batches (nothing is appended in that case).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Randomly flip labels
        randomized_targets = randomize_labels_with_prob(targets, flip_prob=0.2)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, randomized_targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.3f}, Acc: {100.*correct/total:.3f}")

    if num_batches == 0:
        return None

    # One history entry per epoch. Appending inside the batch loop would log a
    # partial running sum divided by the full batch count after every batch.
    avg_loss = total_loss / num_batches
    train_loss_history.append(avg_loss)
    return avg_loss

# Test function
def test():
    """
    Evaluate ``model`` on ``test_loader`` without gradients.

    Labels are used as-is (no flipping). Appends the top-1 accuracy in percent to
    ``test_accuracy_history`` and prints it.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            predictions = model(batch_inputs).max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += predictions.eq(batch_labels).sum().item()

    accuracy = 100. * n_correct / n_seen
    test_accuracy_history.append(accuracy)
    print(f"Test Accuracy: {accuracy:.3f}")

# Alternate one training epoch and one evaluation pass, 5000 times.
for epoch in range(1, 5000 + 1):
    train(epoch)
    test()

# Persist the final weights together with the recorded metric histories.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'train_loss': train_loss_history,
    'test_accuracy': test_accuracy_history,
}
torch.save(checkpoint, 'vgg2_cifar10_lngd2_lrd05_snr40.pth')


